import random
import torch

from .cache import getCtRawCandidate
from .luna_2d_seg import Luna2dSegmentationDataset

# Instead of the full CT slices, we're going to train on 64x64 crops around
# our positive candidates (the actually-a-nodule candidates). These 64x64
# patches will be taken randomly from a 96x96 crop centered on the nodule. We
# will also include three slices of context in both directions as additional
# "channels" to our 2D segmentation.
#
# We're doing this to make training more stable, and to converge more quickly.
# The only reason we know to do this is because we tried to train on whole CT
# slices, but we found the results unsatisfactory. After some experimentation,
# we found that the 64x64 semirandom crop approach worked well, so we decided
# to use that.
#
# We believe the whole-slice training was unstable essentially due to a
# class-balancing issue. Since each nodule is so small compared to the whole CT
# slice, we were right back in a needle-in-a-haystack situation similar to the
# one we got out of in the last chapter, where our positive samples were
# swamped by the negatives. In this case, we're talking about pixels rather
# than nodules, but the concept is the same. By training on crops, we're
# keeping the number of positive pixels the same and reducing the negative
# pixel count by several orders of magnitude.
#
# Because our segmentation model is pixel-to-pixel and takes images of arbitrary
# size, we can get away with training and validating on samples with different
# dimensions. Validation uses the same convolutions with the same weights, just
# applied to a larger set of pixels (and so with fewer border pixels to fill in
# with edge data).
#
# One caveat to this approach is that since our validation set contains orders
# of magnitude more negative pixels, our model will have a huge false positive
# rate during validation. There are many more opportunities for our segmentation
# model to get tricked! It doesn't help that we're going to be pushing for high
# recall as well.


class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    """Training-time dataset that yields random 64x64 crops centered near
    positive (actually-a-nodule) candidates.

    Each sample is cut from a 96x96 patch around the candidate, with 7
    slices of CT context as channels and a single-slice positive mask as
    the label. See the module comment above for the rationale.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): set but never read in this subclass; presumably a
        # negative:positive balancing ratio consumed by the parent class or
        # training loop -- confirm against Luna2dSegmentationDataset.
        self.ratio_int = 2

    def __len__(self):
        # Fixed, arbitrary "epoch" length: __getitem__ wraps around
        # pos_list, so the epoch size is decoupled from the number of
        # positive candidates.
        return 300000

    def shuffleSamples(self):
        # Re-randomize sample order between epochs (called by the training
        # loop; __getitem__ only reads pos_list, candidateInfo_list is
        # shuffled for consistency with the parent's bookkeeping).
        random.shuffle(self.candidateInfo_list)
        random.shuffle(self.pos_list)

    def __getitem__(self, ndx):
        # Modulo-wrap so any index in [0, len(self)) maps onto the much
        # smaller list of positive candidates.
        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
        return self.getitem_trainingCrop(candidateInfo_tup)

    def getitem_trainingCrop(self, candidateInfo_tup):
        """Return (ct_t, pos_t, series_uid, slice_ndx) for one candidate.

        ct_t is a float32 tensor of 7 context slices cropped to 64x64;
        pos_t is the matching int64 single-channel mask crop.
        """
        # We limit our pos_a to the center slice that we're actually segmenting,
        # and then construct our 64x64 random crops of the 96x96
        ct_a, pos_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            (7, 96, 96),
        )

        # Taking a one-element slice keeps the third dimension, which will be
        # the (single) output channel. Index 3 is the center of the 7 slices.
        pos_a = pos_a[3:4]

        # A 64-wide crop of a 96-wide patch has 96 - 64 + 1 = 33 valid
        # positions. randrange(0, 32) would never place the crop flush
        # against the bottom/right edge (off-by-one), so we use 33.
        row_offset = random.randrange(0, 33)
        col_offset = random.randrange(0, 33)

        # fmt: off
        ct_t = torch.from_numpy(ct_a[:, row_offset : row_offset + 64, col_offset : col_offset + 64]).to(torch.float32)
        pos_t = torch.from_numpy(pos_a[:, row_offset : row_offset + 64, col_offset : col_offset + 64]).to(torch.long)
        # fmt: on

        # Index (slice number) of the candidate's center within the CT volume.
        slice_ndx = center_irc.index

        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
